{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# 特征设计与模型训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "from utils import visualization_evaluation,visualization_y\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Tunable parameters for this notebook\n",
    "# Number of training samples kept after oversampling + stratified subsampling\n",
    "tran_size = 20000\n",
    "\n",
    "# Block size (also handed to SpecDetector / visualization_y) and the sensitivity tag used in file names\n",
    "blk_sz = 8\n",
    "sensitivity = 32\n",
    "# Selected band indices; their count is the channel number\n",
    "selected_bands = [127, 201, 202, 294]\n",
    "# Tree count passed to train_rf_and_report and embedded in the model file name\n",
    "tree_num = 185\n",
    "# NOTE(review): pic_row / pic_col are not referenced by any other cell of this notebook\n",
    "pic_row = 600\n",
    "pic_col = 1024\n",
    "\n",
    "# Input dataset and output model paths, both derived from the parameters above\n",
    "dataset_file = f'./dataset/data_{blk_sz}x{blk_sz}_c{len(selected_bands)}_sen{sensitivity}_4.p'\n",
    "model_file = f'./models/rf_{blk_sz}x{blk_sz}_c{len(selected_bands)}_{tree_num}_sen{sensitivity}_4.model'"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 数据集与样本平衡"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Load the pickled dataset, stored as a pair (x_list, y_list)\n",
    "# NOTE(review): pickle.load can execute arbitrary code - only open dataset files you trust\n",
    "with open(dataset_file, 'rb') as f:\n",
    "    x_list, y_list = pickle.load(f)\n",
    "# Make sure the number of samples and the number of labels match\n",
    "assert len(x_list) == len(y_list)\n",
    "print(\"数据量： \", len(x_list))\n",
    "x = np.asarray(x_list)\n",
    "y = np.asarray(y_list, dtype=int)\n",
    "# Stratified, shuffled 70/30 split with a fixed seed\n",
    "x_train, x_test, y_train, y_test = train_test_split(\n",
    "    x, y, test_size=0.3, random_state=5, shuffle=True, stratify=y\n",
    ")\n",
    "print(f\"x train {x_train.shape}, y train {y_train.shape}\\n\"\n",
    "      f\"x test {x_test.shape}, y test {y_test.shape}\")"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Compare the label distribution of the full set with the train / test splits.\n",
    "# The bin edges are shared by all three panels so the histograms are directly comparable.\n",
    "label_bin_edges = [0, 1, 2, 3, 4, 5, 6, 7]\n",
    "fig, axs = plt.subplots(1, 3, figsize=(12, 6))\n",
    "hist_res_total = axs[0].hist(y, label_bin_edges, align='mid')\n",
    "hist_res_train = axs[1].hist(y_train, label_bin_edges, align='mid')\n",
    "hist_res_test = axs[2].hist(y_test, label_bin_edges, align='mid')\n",
    "# Title and label every panel so the figure can be read without the code\n",
    "for ax, title in zip(axs, ('total', 'train', 'test')):\n",
    "    ax.set_title(title)\n",
    "    ax.set_xlabel('label')\n",
    "    ax.set_ylabel('count')\n",
    "print(f'total {hist_res_total} \\n'\n",
    "      f'train {hist_res_train} \\n'\n",
    "      f'test {hist_res_test}')\n",
    "plt.show()"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# The classes turned out to be heavily imbalanced, so oversample the training split\n",
    "from imblearn.over_sampling import RandomOverSampler\n",
    "ros = RandomOverSampler(random_state=0)\n",
    "# Flatten every sample to a vector for resampling; the per-sample shape is restored at the end\n",
    "x_train_shape = x_train.shape\n",
    "x_train = x_train.reshape((x_train.shape[0], -1))\n",
    "x_resampled, y_resampled = ros.fit_resample(x_train, y_train)\n",
    "# Plot the label histogram after oversampling\n",
    "fig, axs = plt.subplots(figsize=(10, 6))\n",
    "hist_res_train = axs.hist(y_resampled, [0, 1, 2, 3, 4, 5, 6, 7], align='mid')\n",
    "print(f'train {hist_res_train} \\n')\n",
    "plt.show()\n",
    "# Stratified subsample down to tran_size samples.\n",
    "# train_test_split raises ValueError when train_size is not smaller than the number of\n",
    "# samples, so keep the whole resampled set when it holds no more than tran_size samples.\n",
    "if len(y_resampled) > tran_size:\n",
    "    x_train, _, y_train, _ = train_test_split(x_resampled, y_resampled, train_size=tran_size, random_state=0,\n",
    "                                              shuffle=True, stratify=y_resampled)\n",
    "else:\n",
    "    x_train, y_train = x_resampled, y_resampled\n",
    "# Plot the label histogram of the final training set\n",
    "fig, axs = plt.subplots(figsize=(10, 6))\n",
    "hist_res_train = axs.hist(y_train, [0, 1, 2, 3, 4, 5, 6, 7], align='mid')\n",
    "print(f'train {hist_res_train} \\n')\n",
    "plt.show()\n",
    "# Restore the original per-sample shape, whatever its rank\n",
    "# NOTE(review): this cell overwrites x_train / y_train, so re-running it without re-running\n",
    "# the split cell resamples the already balanced subset\n",
    "x_train = x_train.reshape((-1, *x_train_shape[1:]))\n",
    "print(len(x_train))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 特征设计\n",
    "使用 `models.feature` 分别对训练集和测试集样本提取特征。"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "from models import feature"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Run the feature extractor on the training subset and on the held-out test split\n",
    "features_train, features_test = feature(x_train), feature(x_test)\n",
    "print(features_test.shape)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 进行训练，验证特征好坏"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "from models import train_rf_and_report, SpecDetector\n",
    "# Report on the held-out split only: the full set (x, y) contains the samples the training\n",
    "# subset was drawn from, so scoring on it would overstate the model's quality.\n",
    "# NOTE(review): presumably the 3rd / 4th positional arguments are the evaluation features and\n",
    "# labels - confirm against models.train_rf_and_report\n",
    "clf = train_rf_and_report(features_train, y_train, features_test, y_test, tree_num, save_path=model_file)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Build a detector from the model file (the same path given to the training cell as save_path)\n",
    "model = SpecDetector(model_path=model_file, blk_sz=blk_sz, channel_num=len(selected_bands))\n",
    "# Plot the detector's results to check them visually\n",
    "# NOTE(review): machine-specific absolute path - adjust it to where the evaluation data lives\n",
    "eval_data_path = r'F:\\zhouchao\\617\\617_pure_tobacco'\n",
    "visualization_evaluation(detector=model, data_path=eval_data_path, selected_bands=selected_bands)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Visualize the labels loaded from the dataset file (y_list), passing the configured block size\n",
    "visualization_y(y_list, blk_sz)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.7 ('base')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "vscode": {
   "interpreter": {
    "hash": "7f619fc91ee8bdab81d49e7c14228037474662e3f2d607687ae505108922fa06"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}